Add Bamba Model #10909
Conversation
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Hi @fabianlim, thanks for the PR! It's really great to see progress being made on state-space models, especially for me as I unfortunately haven't been able to prioritize support for Mamba2. I'm happy to shepherd this PR and discuss any questions you have, especially to support chunked prefill. If you haven't already, can you join the developer slack for quicker discussion? (https://communityinviter.com/apps/vllm-dev/join-vllm-developers-slack)
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@tlrmchlsmth I cleaned up the PR quite a bit, so perhaps it might be a good time to get some early eyes on it. The chunked prefill implementation is incomplete ATM, as we discussed offline.
first pass, just a few comments. At a high level it looks good.
Will you add a test for tensor parallelism?
# will be ch
MODELS = ["ibm-fms/Bamba-9.8b-1.8T-hf"]
The comment trails off, but will there be a small test model available?
@raghukiran1224 any plans for a small test model? Since we compare outputs, I think it is not that good to just have a randomly initialised small model.
@fabianlim @tlrmchlsmth would it be ok to test with a random model or would you rather have a tiny model (say 200M or so) to test with?
A tiny model with nonrandom weights would be much better!
btw is there any update on this?
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
@tlrmchlsmth I have addressed most of your comments now. I am not rebasing yet, waiting for you to look first. But I realized
@fabianlim At a high level, the changes look good, and the PR looks good overall. I'll do a more thorough review once it's unmarked as draft.
Could you add unit tests for the added kernels in layers/mamba/ops?
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]> Co-authored-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
lora_config = vllm_config.lora_config

self.config = config
self.padding_idx = config.pad_token_id
is this one used anywhere?
Oh sorry, good catch. I will remove it.
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Note to self: some of the testing API has changed due to PR #10353.
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Looking pretty good, let's get this landed now that 4.48.2 is out!
vllm/model_executor/models/mamba2.py
Let's land mamba 2 in #9292
requirements-common.txt
Could you merge in latest main? We've already landed this change
Ya, actually yesterday I reverted this file and took it from latest main, but somehow the diff still shows up in GitHub. The version on the left shown by GitHub is actually old.
OK, if I merge in latest main it seems fine.
requirements-test.txt
same as above, the version on the left is old. the right is from latest main
Let's land these changes as part of #9292
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):

    def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.full_hidden_size = full_hidden_size
        self.group_size = full_hidden_size // full_n_groups
        self.per_rank_hidden_size = full_hidden_size // self.tp_size
        self.n_groups = full_hidden_size // self.group_size

        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
        set_weight_attrs(self.weight,
                         {"weight_loader": sharded_weight_loader(0)})
        assert self.full_hidden_size % self.tp_size == 0, \
            "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        # Three tensor-parallel cases:
        #   1. n_groups is 1
        #      In this case we parallelize along the reduction dim.
        #      Each rank computes a local sum of squares followed by AllReduce.
        #   2. tp_size divides n_groups
        #      Each rank only reduces within its local group(s).
        #      No collective ops necessary.
        #   3. The general case can be pretty complicated, so we AllGather
        #      the input and then redundantly compute the RMSNorm.
        input_dtype = x.dtype
        x = x * nn.functional.silu(gate.to(torch.float32))

        if self.n_groups == 1:
            if self.tp_size > 1:
                # Compute local sum and then reduce to obtain global sum
                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
                global_sums = tensor_model_parallel_all_reduce(local_sums)
                # Calculate the variance
                count = self.tp_size * x.shape[-1]
                variance = (global_sums / count)
            else:
                variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.variance_epsilon)
        else:
            redundant_tp: bool = self.n_groups % self.tp_size != 0
            if redundant_tp:
                # To handle the general case, redundantly apply the variance
                x = tensor_model_parallel_all_gather(x, -1)

            *prefix_dims, hidden_dim = x.shape
            group_count = hidden_dim // self.group_size
            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
            variance = x_grouped.pow(2).mean(-1, keepdim=True)
            x_grouped = x_grouped * torch.rsqrt(variance +
                                                self.variance_epsilon)
            x = x_grouped.view(*prefix_dims, hidden_dim)

            if redundant_tp:
                start = self.per_rank_hidden_size * self.tp_rank
                end = start + self.per_rank_hidden_size
                x = x[..., start:end]

        return self.weight * x.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

        if self.tp_size > 1 or self.n_groups != 1:
            return self.forward_native(x, gate)

        from vllm import _custom_ops as ops

        # cast x and gate to float32 before silu
        out = torch.empty_like(x)
        y = x * nn.functional.silu(gate.to(torch.float32))
        ops.rms_norm(
            out,
            y.to(x.dtype),
            self.weight.data,
            self.variance_epsilon,
        )
        return out
We should get a unit test in place for this, especially the various tensor parallel cases. @fabianlim do you have bandwidth to do that? Otherwise I can do it either in #9292 or a separate PR. I do feel pretty good about correctness here, having manually tested various cases thoroughly enough.
Just unit testing only the Mixer2RMSNormGated? If so, how would you set up the test? conftest only has runners for the whole model.
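For what it's worth, here is a rough standalone sketch of how the math in Mixer2RMSNormGated could be unit-tested without the whole-model runners in conftest. The function and test names below are illustrative, not from this PR, and the sketch only covers the single-rank math; comparing against forward_native under real tensor parallelism would additionally require initializing vLLM's distributed state.

import torch
import torch.nn.functional as F


def ref_gated_grouped_rms_norm(x, gate, weight, group_size, eps=1e-6):
    # Illustrative plain-PyTorch reference for the gated, grouped RMSNorm:
    # y = x * SiLU(gate) in float32, then RMSNorm applied per group of channels.
    input_dtype = x.dtype
    y = x * F.silu(gate.to(torch.float32))
    *prefix, hidden = y.shape
    y = y.view(*prefix, hidden // group_size, group_size)
    variance = y.pow(2).mean(-1, keepdim=True)
    y = (y * torch.rsqrt(variance + eps)).view(*prefix, hidden)
    return weight * y.to(input_dtype)


def test_single_group_reduces_to_plain_rms_norm():
    torch.manual_seed(0)
    hidden = 64
    x = torch.randn(2, 8, hidden)
    gate = torch.randn_like(x)
    weight = torch.ones(hidden)

    # With a single group spanning the full hidden dim, the grouped norm
    # should match an ordinary RMSNorm over the whole hidden dim.
    out = ref_gated_grouped_rms_norm(x, gate, weight, group_size=hidden)

    y = x * F.silu(gate)
    expected = weight * (y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + 1e-6))
    torch.testing.assert_close(out, expected, rtol=1e-5, atol=1e-5)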
@@ -69,6 +70,7 @@
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
    "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"),
Let's land this in #9292
To debug the pre-commit issue locally you may need to run:
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@tlrmchlsmth thank you so much for your comments. I have fixed the
This is the companion PR to a Hugging Face PR for adding Bamba, a hybrid Mamba2 architecture with SwiGLU. The checkpoints are jointly trained by IBM, Princeton, and UIUC. In this PR we have:
- The bamba model inference architecture. We would like to acknowledge the jamba team, whose implementation we referenced and modified to support full attention layers with RoPE and Mamba v2. An earlier limitation, where the continuous-batching boundaries had to line up with the chunk boundaries, is now completely fixed.
- Kernels in vllm/model_executor/layers/mamba/ops. Only the fwd kernels are extracted; some modifications and fixes are made.
- tests/models/decoder_only/language/test_bamba.py with an initial ibm-fms/Bamba-9.8b-1.8T-hf checkpoint. It is practically identical to test_mamba.py, except that chunked prefill tests are disabled as chunked prefill is currently not supported. (A minimal usage sketch follows after this list.)

Currently only the FlashAttention backend is supported, as we check fields like context_lens_tensor; we have not yet investigated other backends.

We would like to also acknowledge the draft Codestral Mamba PR from @tlrmchlsmth, from which we also referenced the mixer.
Hope to discuss the following with the maintainers:
- Do we have to remove all the bwd kernels? Yes, we should.
- We extend the sin_cos cache to cover the sequence length if it is longer than max_sequence_len. This differs from other current models (e.g., llama). How can we better support long sequence lengths? We should keep this consistent with other models, so we propose to allow the sin_cos cache extension only when VLLM_ALLOW_LONG_MAX_MODEL_LEN is specified (see the sketch after this list).
- We have some ideas to support chunked prefill, but would appreciate some discussion with the maintainers on how to proceed; we are now working on changing the kernels to support chunked prefill.
- Since the mixer2 is simplified from mamba, should we rename it? We can keep it as is, but we should document the differences from mamba_ssm.
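A rough sketch of the sin_cos gating proposed above (illustrative only: maybe_extend_sin_cos_cache and the extend_cache hook are placeholder names, not the rotary-embedding code in this PR):

import os


def maybe_extend_sin_cos_cache(rotary_emb, seq_len: int, max_model_len: int):
    # Placeholder sketch: only grow the rope sin/cos cache past max_model_len
    # when the user explicitly opts in via VLLM_ALLOW_LONG_MAX_MODEL_LEN.
    if seq_len <= max_model_len:
        return
    if os.getenv("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0") != "1":
        raise ValueError(
            f"Sequence length {seq_len} exceeds max_model_len {max_model_len}; "
            "set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 to allow extending the cache.")
    rotary_emb.extend_cache(seq_len)  # hypothetical hook on the rope module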
cc: @ani300, @raghukiran1224, @cyang49, @njhill